{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/AviSoori1x/makeMoE/blob/main/makeMoE_from_Scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "8e9b80fc-12cf-41a9-a0de-354f678b412b",
          "showTitle": false,
          "title": ""
        },
        "id": "90vgVgmDkRJQ"
      },
      "source": [
        "#### Sparse mixture of experts language model from scratch inspired by (and largely based on) Andrej Karpathy's makemore (https://github.com/karpathy/makemore) :)\n",
        "\n",
        "This is a from scratch implementation of a sparse mixture of experts language model. This is inspired by and largely based on Andrej Karpathy's project 'makemore' and borrows most of the re-usable components from that implementation. Just like makemore, makeMoE is also an autoregressive character-level language model but uses the aforementioned sparse mixture of experts architecture.\n",
        "\n",
        "Just like makemore, pytorch is the only requirement (so I hope the from scratch claim is justified).\n",
        "\n",
        "Significant Changes from the makemore architecture\n",
        "\n",
        "- Sparse mixture of experts instead of the solitary feed forward neural net.\n",
        "- Top-k gating and noisy top-k gating implementations.\n",
        "- initialization - Kaiming He initialization is used here but the point of this notebook is to be hackable so you can swap in Xavier Glorot etc. and take it for a spin.\n",
        "\n",
        "Unchanged from makemore\n",
        "- The dataset, preprocessing (tokenization), and the language modeling task Andrej chose originally - generate Shakespeare-like text\n",
        "- Casusal self attention implementation\n",
        "- Training loop\n",
        "- Inference logic\n",
        "\n",
        "Publications heavily referenced for this implementation:\n",
        "- Mixtral of experts: https://arxiv.org/pdf/2401.04088.pdf\n",
        "- Outrageosly Large Neural Networks: The Sparsely-Gated Mixture-Of-Experts layer: https://arxiv.org/pdf/1701.06538.pdf\n",
        "\n",
        "\n",
        "This notebook walks through the intuition for the entire model architecture and how everything comes together\n",
        "\n",
        "The code was entirely developed on Databricks using a single A100 for compute. If you're running this on Databricks, you can scale this on an arbitrarily large GPU cluster with no issues in the cloud provider of your choice\n",
        "\n",
        "I chose to use mlflow (which comes pre-installed in Databricks. You can pip install easily elsewhere) as I find it helpful to track and log all the metrics necessary. This is entirely optional.\n",
        "\n",
        "Please note that the implementation emphasizes readability and hackability vs performance, so there are many ways in which you could improve this. Please try and let me know\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "2f4a58a8-bd4c-40de-a4a9-95457842db0b",
          "showTitle": false,
          "title": ""
        },
        "id": "hywLNfb0kRJT"
      },
      "source": [
        "![mixture of experts overview](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/moe.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "35b3daa3-3b3b-47af-b3e7-be95878f9e06",
          "showTitle": false,
          "title": ""
        },
        "id": "RQAnP6_RkRJU"
      },
      "outputs": [],
      "source": [
        "#Using mlflow is entirely optional. I personally like to use MLFlow to track and log everything. If you're using Databricks, it comes pre-installed.\n",
        "%pip install mlflow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "5e1a3e38-8717-42ec-9bbc-71d3712c1c68",
          "showTitle": false,
          "title": ""
        },
        "id": "V521QQ_qkRJV"
      },
      "outputs": [],
      "source": [
        "#Import the necessary packages and set seed for reproducibility. For this notebook, pytorch is all you need\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "torch.manual_seed(42)\n",
        "#Optional\n",
        "import mlflow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "faf99ef2-39bb-46fc-b772-05d6d0482bbc",
          "showTitle": false,
          "title": ""
        },
        "id": "-4r_QNRRkRJV"
      },
      "source": [
        "Next few sections, downloading the data, preprocessing it and self attention are directly from makemore. I have elaborated a little on self attention and added visual aids to understand the process a bit better."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "45143d84-28c7-463d-9fb5-e21122842600",
          "showTitle": false,
          "title": ""
        },
        "id": "2GhDw0yWkRJV",
        "outputId": "22b79649-9819-4c65-a46a-d0796d6c9085"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2024-01-25 06:11:57--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt\r\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.109.133, ...\r\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\r\n",
            "HTTP request sent, awaiting response... 200 OK\r\n",
            "Length: 1115394 (1.1M) [text/plain]\r\n",
            "Saving to: ‘input.txt.2’\r\n",
            "\r\n",
            "\rinput.txt.2           0%[                    ]       0  --.-KB/s               \rinput.txt.2         100%[===================>]   1.06M  --.-KB/s    in 0.04s   \r\n",
            "\r\n",
            "2024-01-25 06:11:57 (30.2 MB/s) - ‘input.txt.2’ saved [1115394/1115394]\r\n",
            "\r\n"
          ]
        }
      ],
      "source": [
        "# Downloading the tiny shakespeare dataset\n",
        "!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "192e830a-762d-4573-9484-70a58deb1fec",
          "showTitle": false,
          "title": ""
        },
        "id": "3sPAL1AKkRJW"
      },
      "outputs": [],
      "source": [
        "# read it in to inspect it\n",
        "with open('input.txt', 'r', encoding='utf-8') as f:\n",
        "    text = f.read()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "2d7181b7-f5e5-4ab5-bdd8-74c507c798ad",
          "showTitle": false,
          "title": ""
        },
        "id": "wNkF3RYLkRJX",
        "outputId": "671be677-be3d-49e0-d391-569559a99ae2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "length of dataset in characters:  1115394\n"
          ]
        }
      ],
      "source": [
        "print(\"length of dataset in characters: \", len(text))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "68032e07-8625-4750-a340-bc8f4eed2458",
          "showTitle": false,
          "title": ""
        },
        "id": "AHIwr-yxkRJX",
        "outputId": "fa10ee0f-c8cb-4ba1-9fb1-79f9fbf336fe"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "First Citizen:\n",
            "Before we proceed any further, hear me speak.\n",
            "\n",
            "All:\n",
            "Speak, speak.\n",
            "\n",
            "First Citizen:\n",
            "You are all resolved rather to die than to famish?\n",
            "\n",
            "All:\n",
            "Resolved. resolved.\n",
            "\n",
            "First Citizen:\n",
            "First, you know Caius Marcius is chief enemy to the people.\n",
            "\n",
            "All:\n",
            "We know't, we know't.\n",
            "\n",
            "First Citizen:\n",
            "Let us kill him, and we'll have corn at our own price.\n",
            "Is't a verdict?\n",
            "\n",
            "All:\n",
            "No more talking on't; let it be done: away, away!\n",
            "\n",
            "Second Citizen:\n",
            "One word, good citizens.\n",
            "\n",
            "First Citizen:\n",
            "We are accounted poor citizens, the patricians good.\n",
            "What authority surfeits on would relieve us: if they\n",
            "would yield us but the superfluity, while it were\n",
            "wholesome, we might guess they relieved us humanely;\n",
            "but they think we are too dear: the leanness that\n",
            "afflicts us, the object of our misery, is as an\n",
            "inventory to particularise their abundance; our\n",
            "sufferance is a gain to them Let us revenge this with\n",
            "our pikes, ere we become rakes: for the gods know I\n",
            "speak this in hunger for bread, not in thirst for revenge.\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# let's look at the first 1000 characters\n",
        "print(text[:1000])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "b6995ad6-c9ac-4a21-9da0-ebbd3273c991",
          "showTitle": false,
          "title": ""
        },
        "id": "DHGayz7mkRJY",
        "outputId": "54b52e55-b6b9-4bdc-974a-5c48a43e1497"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
            "65\n"
          ]
        }
      ],
      "source": [
        "# here are all the unique characters that occur in this text\n",
        "chars = sorted(list(set(text)))\n",
        "vocab_size = len(chars)\n",
        "print(''.join(chars))\n",
        "print(vocab_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "43002fa3-ffd3-416c-9aaf-0a03b19c7bc1",
          "showTitle": false,
          "title": ""
        },
        "id": "pzn11WJckRJY",
        "outputId": "cd7b3ff1-680b-4036-b3c8-b597a361aeb3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
            "hii there\n"
          ]
        }
      ],
      "source": [
        "# create a mapping from characters to integers\n",
        "stoi = { ch:i for i,ch in enumerate(chars) }\n",
        "itos = { i:ch for i,ch in enumerate(chars) }\n",
        "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
        "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
        "\n",
        "print(encode(\"hii there\"))\n",
        "print(decode(encode(\"hii there\")))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "b4609fc4-09c7-4a39-8367-e9ee39d440ed",
          "showTitle": false,
          "title": ""
        },
        "id": "YbBGz0O2kRJY",
        "outputId": "49e94dde-6e1f-4a9e-a18b-9b8f1c51de86"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "torch.Size([1115394]) torch.int64\n",
            "tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,\n",
            "        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,\n",
            "         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,\n",
            "        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,\n",
            "         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,\n",
            "        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,\n",
            "         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,\n",
            "        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,\n",
            "        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,\n",
            "         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,\n",
            "         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,\n",
            "        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,\n",
            "        47, 59, 57,  1, 47, 57,  1, 41, 46, 47, 43, 44,  1, 43, 52, 43, 51, 63,\n",
            "         1, 58, 53,  1, 58, 46, 43,  1, 54, 43, 53, 54, 50, 43,  8,  0,  0, 13,\n",
            "        50, 50, 10,  0, 35, 43,  1, 49, 52, 53, 61,  5, 58,  6,  1, 61, 43,  1,\n",
            "        49, 52, 53, 61,  5, 58,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58,\n",
            "        47, 64, 43, 52, 10,  0, 24, 43, 58,  1, 59, 57,  1, 49, 47, 50, 50,  1,\n",
            "        46, 47, 51,  6,  1, 39, 52, 42,  1, 61, 43,  5, 50, 50,  1, 46, 39, 60,\n",
            "        43,  1, 41, 53, 56, 52,  1, 39, 58,  1, 53, 59, 56,  1, 53, 61, 52,  1,\n",
            "        54, 56, 47, 41, 43,  8,  0, 21, 57,  5, 58,  1, 39,  1, 60, 43, 56, 42,\n",
            "        47, 41, 58, 12,  0,  0, 13, 50, 50, 10,  0, 26, 53,  1, 51, 53, 56, 43,\n",
            "         1, 58, 39, 50, 49, 47, 52, 45,  1, 53, 52,  5, 58, 11,  1, 50, 43, 58,\n",
            "         1, 47, 58,  1, 40, 43,  1, 42, 53, 52, 43, 10,  1, 39, 61, 39, 63,  6,\n",
            "         1, 39, 61, 39, 63,  2,  0,  0, 31, 43, 41, 53, 52, 42,  1, 15, 47, 58,\n",
            "        47, 64, 43, 52, 10,  0, 27, 52, 43,  1, 61, 53, 56, 42,  6,  1, 45, 53,\n",
            "        53, 42,  1, 41, 47, 58, 47, 64, 43, 52, 57,  8,  0,  0, 18, 47, 56, 57,\n",
            "        58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 35, 43,  1, 39, 56, 43,  1,\n",
            "        39, 41, 41, 53, 59, 52, 58, 43, 42,  1, 54, 53, 53, 56,  1, 41, 47, 58,\n",
            "        47, 64, 43, 52, 57,  6,  1, 58, 46, 43,  1, 54, 39, 58, 56, 47, 41, 47,\n",
            "        39, 52, 57,  1, 45, 53, 53, 42,  8,  0, 35, 46, 39, 58,  1, 39, 59, 58,\n",
            "        46, 53, 56, 47, 58, 63,  1, 57, 59, 56, 44, 43, 47, 58, 57,  1, 53, 52,\n",
            "         1, 61, 53, 59, 50, 42,  1, 56, 43, 50, 47, 43, 60, 43,  1, 59, 57, 10,\n",
            "         1, 47, 44,  1, 58, 46, 43, 63,  0, 61, 53, 59, 50, 42,  1, 63, 47, 43,\n",
            "        50, 42,  1, 59, 57,  1, 40, 59, 58,  1, 58, 46, 43,  1, 57, 59, 54, 43,\n",
            "        56, 44, 50, 59, 47, 58, 63,  6,  1, 61, 46, 47, 50, 43,  1, 47, 58,  1,\n",
            "        61, 43, 56, 43,  0, 61, 46, 53, 50, 43, 57, 53, 51, 43,  6,  1, 61, 43,\n",
            "         1, 51, 47, 45, 46, 58,  1, 45, 59, 43, 57, 57,  1, 58, 46, 43, 63,  1,\n",
            "        56, 43, 50, 47, 43, 60, 43, 42,  1, 59, 57,  1, 46, 59, 51, 39, 52, 43,\n",
            "        50, 63, 11,  0, 40, 59, 58,  1, 58, 46, 43, 63,  1, 58, 46, 47, 52, 49,\n",
            "         1, 61, 43,  1, 39, 56, 43,  1, 58, 53, 53,  1, 42, 43, 39, 56, 10,  1,\n",
            "        58, 46, 43,  1, 50, 43, 39, 52, 52, 43, 57, 57,  1, 58, 46, 39, 58,  0,\n",
            "        39, 44, 44, 50, 47, 41, 58, 57,  1, 59, 57,  6,  1, 58, 46, 43,  1, 53,\n",
            "        40, 48, 43, 41, 58,  1, 53, 44,  1, 53, 59, 56,  1, 51, 47, 57, 43, 56,\n",
            "        63,  6,  1, 47, 57,  1, 39, 57,  1, 39, 52,  0, 47, 52, 60, 43, 52, 58,\n",
            "        53, 56, 63,  1, 58, 53,  1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,\n",
            "        57, 43,  1, 58, 46, 43, 47, 56,  1, 39, 40, 59, 52, 42, 39, 52, 41, 43,\n",
            "        11,  1, 53, 59, 56,  0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43,  1, 47,\n",
            "        57,  1, 39,  1, 45, 39, 47, 52,  1, 58, 53,  1, 58, 46, 43, 51,  1, 24,\n",
            "        43, 58,  1, 59, 57,  1, 56, 43, 60, 43, 52, 45, 43,  1, 58, 46, 47, 57,\n",
            "         1, 61, 47, 58, 46,  0, 53, 59, 56,  1, 54, 47, 49, 43, 57,  6,  1, 43,\n",
            "        56, 43,  1, 61, 43,  1, 40, 43, 41, 53, 51, 43,  1, 56, 39, 49, 43, 57,\n",
            "        10,  1, 44, 53, 56,  1, 58, 46, 43,  1, 45, 53, 42, 57,  1, 49, 52, 53,\n",
            "        61,  1, 21,  0, 57, 54, 43, 39, 49,  1, 58, 46, 47, 57,  1, 47, 52,  1,\n",
            "        46, 59, 52, 45, 43, 56,  1, 44, 53, 56,  1, 40, 56, 43, 39, 42,  6,  1,\n",
            "        52, 53, 58,  1, 47, 52,  1, 58, 46, 47, 56, 57, 58,  1, 44, 53, 56,  1,\n",
            "        56, 43, 60, 43, 52, 45, 43,  8,  0,  0])\n"
          ]
        }
      ],
      "source": [
        "# let's now encode the entire text dataset and store it into a torch.Tensor\n",
        "data = torch.tensor(encode(text), dtype=torch.long)\n",
        "print(data.shape, data.dtype)\n",
        "print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "88f7cb0d-02ff-42b0-92a5-a505dc3f8f25",
          "showTitle": false,
          "title": ""
        },
        "id": "hoLIeA7YkRJZ"
      },
      "outputs": [],
      "source": [
        "# Let's now split up the data into train and validation sets\n",
        "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
        "train_data = data[:n]\n",
        "val_data = data[n:]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "6b554ddf-50f4-441b-8acf-10b81a508b7e",
          "showTitle": false,
          "title": ""
        },
        "id": "VY55nr6EkRJZ",
        "outputId": "1a5685e9-66bb-40ed-f19a-f0cb09e8e51c"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])"
            ]
          },
          "execution_count": 26,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "block_size = 8\n",
        "train_data[:block_size+1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "22ba4512-309d-4895-a908-2ef3efa317bc",
          "showTitle": false,
          "title": ""
        },
        "id": "5YbgrB9HkRJZ",
        "outputId": "05577a5f-10d5-495b-82e8-c3278b4fdea0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "when input is tensor([18]) the target: 47\n",
            "when input is tensor([18, 47]) the target: 56\n",
            "when input is tensor([18, 47, 56]) the target: 57\n",
            "when input is tensor([18, 47, 56, 57]) the target: 58\n",
            "when input is tensor([18, 47, 56, 57, 58]) the target: 1\n",
            "when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15\n",
            "when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47\n",
            "when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58\n"
          ]
        }
      ],
      "source": [
        "x = train_data[:block_size]\n",
        "y = train_data[1:block_size+1]\n",
        "for t in range(block_size):\n",
        "    context = x[:t+1]\n",
        "    target = y[t]\n",
        "    print(f\"when input is {context} the target: {target}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "bf386bff-0f63-4358-82fc-6c7d02c37321",
          "showTitle": false,
          "title": ""
        },
        "id": "Oaxhage8kRJZ"
      },
      "outputs": [],
      "source": [
        "batch_size = 4 # how many independent sequences will we process in parallel?\n",
        "block_size = 8 # what is the maximum context length for predictions?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "99acd85c-233f-4f2d-a062-028dbcde9960",
          "showTitle": false,
          "title": ""
        },
        "id": "HfpkIUNdkRJZ",
        "outputId": "1ad2bc10-2b4a-458b-a6ee-a169728ec90d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([250930, 237205, 974116, 383898])"
            ]
          },
          "execution_count": 29,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "ix = torch.randint(len(data) - block_size, (batch_size,))\n",
        "ix"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "e46dc826-9f39-4aed-b2d2-f9ea401136de",
          "showTitle": false,
          "title": ""
        },
        "id": "faoGVPG3kRJa",
        "outputId": "dcf8b139-d8f8-4565-a962-62d094c5da9f"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[42,  1, 58, 46, 59, 57,  1, 21],\n",
              "        [54, 56, 47, 43, 57, 58, 11,  0],\n",
              "        [49, 47, 52, 45, 12,  1, 58, 46],\n",
              "        [58, 46, 53, 59, 58,  1, 56, 43]])"
            ]
          },
          "execution_count": 30,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "x = torch.stack([data[i:i+block_size] for i in ix])\n",
        "y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
        "x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "2886aedf-200e-40bd-9a9d-4658cf6c509b",
          "showTitle": false,
          "title": ""
        },
        "id": "hkllYFCPkRJa",
        "outputId": "9350713c-be0f-492a-c873-bd279226668d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[ 1, 58, 46, 59, 57,  1, 21,  1],\n",
              "        [56, 47, 43, 57, 58, 11,  0, 37],\n",
              "        [47, 52, 45, 12,  1, 58, 46, 53],\n",
              "        [46, 53, 59, 58,  1, 56, 43, 42]])"
            ]
          },
          "execution_count": 31,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "a486fc04-ed29-456f-918b-5f8395e455cb",
          "showTitle": false,
          "title": ""
        },
        "id": "tWrajECBkRJa"
      },
      "source": [
        "The following code block clearly shows the autoregressive nature of the prediction and how the context is a rolling windows over a 1 dimentional arrangement of tokens (characters in this case)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "49a86e10-ac37-4b92-8f18-775cd4853fdc",
          "showTitle": false,
          "title": ""
        },
        "id": "xjgtxxztkRJa",
        "outputId": "1be0ebc2-b22d-4136-e77d-ecdafded66d2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "inputs:\n",
            "torch.Size([4, 8])\n",
            "tensor([[ 6,  0, 14, 43, 44, 53, 56, 43],\n",
            "        [39,  1, 42, 59, 43,  1, 39, 52],\n",
            "        [47, 41, 43,  1, 39, 52, 42,  1],\n",
            "        [53, 44,  1, 50, 43, 58,  1, 58]])\n",
            "targets:\n",
            "torch.Size([4, 8])\n",
            "tensor([[ 0, 14, 43, 44, 53, 56, 43,  1],\n",
            "        [ 1, 42, 59, 43,  1, 39, 52, 42],\n",
            "        [41, 43,  1, 39, 52, 42,  1, 42],\n",
            "        [44,  1, 50, 43, 58,  1, 58, 46]])\n",
            "----\n",
            "when input is [6] the target: 0\n",
            "when input is [6, 0] the target: 14\n",
            "when input is [6, 0, 14] the target: 43\n",
            "when input is [6, 0, 14, 43] the target: 44\n",
            "when input is [6, 0, 14, 43, 44] the target: 53\n",
            "when input is [6, 0, 14, 43, 44, 53] the target: 56\n",
            "when input is [6, 0, 14, 43, 44, 53, 56] the target: 43\n",
            "when input is [6, 0, 14, 43, 44, 53, 56, 43] the target: 1\n",
            "when input is [39] the target: 1\n",
            "when input is [39, 1] the target: 42\n",
            "when input is [39, 1, 42] the target: 59\n",
            "when input is [39, 1, 42, 59] the target: 43\n",
            "when input is [39, 1, 42, 59, 43] the target: 1\n",
            "when input is [39, 1, 42, 59, 43, 1] the target: 39\n",
            "when input is [39, 1, 42, 59, 43, 1, 39] the target: 52\n",
            "when input is [39, 1, 42, 59, 43, 1, 39, 52] the target: 42\n",
            "when input is [47] the target: 41\n",
            "when input is [47, 41] the target: 43\n",
            "when input is [47, 41, 43] the target: 1\n",
            "when input is [47, 41, 43, 1] the target: 39\n",
            "when input is [47, 41, 43, 1, 39] the target: 52\n",
            "when input is [47, 41, 43, 1, 39, 52] the target: 42\n",
            "when input is [47, 41, 43, 1, 39, 52, 42] the target: 1\n",
            "when input is [47, 41, 43, 1, 39, 52, 42, 1] the target: 42\n",
            "when input is [53] the target: 44\n",
            "when input is [53, 44] the target: 1\n",
            "when input is [53, 44, 1] the target: 50\n",
            "when input is [53, 44, 1, 50] the target: 43\n",
            "when input is [53, 44, 1, 50, 43] the target: 58\n",
            "when input is [53, 44, 1, 50, 43, 58] the target: 1\n",
            "when input is [53, 44, 1, 50, 43, 58, 1] the target: 58\n",
            "when input is [53, 44, 1, 50, 43, 58, 1, 58] the target: 46\n"
          ]
        }
      ],
      "source": [
        "def get_batch(split):\n",
        "    # generate a small batch of data of inputs x and targets y\n",
        "    data = train_data if split == 'train' else val_data\n",
        "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
        "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
        "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
        "    return x, y\n",
        "\n",
        "xb, yb = get_batch('train')\n",
        "print('inputs:')\n",
        "print(xb.shape)\n",
        "print(xb)\n",
        "print('targets:')\n",
        "print(yb.shape)\n",
        "print(yb)\n",
        "\n",
        "print('----')\n",
        "\n",
        "for b in range(batch_size): # batch dimension\n",
        "    for t in range(block_size): # time dimension\n",
        "        context = xb[b, :t+1]\n",
        "        target = yb[b,t]\n",
        "        print(f\"when input is {context.tolist()} the target: {target}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "dde3273f-0519-4108-ba84-dfd99e020722",
          "showTitle": false,
          "title": ""
        },
        "id": "RlON_gNikRJa"
      },
      "source": [
        "### Understanding the intuition of Causal Scaled Dot Product Self Attention\n",
        "\n",
        "The code in this section is borrowed from Andrej Karpathy's excellent makemore repository, linked at the top of this notebook."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "e435d0cf-1383-446a-9026-cd80b4266019",
          "showTitle": false,
          "title": ""
        },
        "id": "uBVWP40SkRJa"
      },
      "source": [
        "![scaled dot product self attention](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/self_attention.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "97660589-1719-48c0-ad6f-4f7c2888348a",
          "showTitle": false,
          "title": ""
        },
        "id": "i-jFgMntkRJa"
      },
      "source": [
        "The code below demonstrates the mechanics of classic scaled dot-product self-attention. In this variant, the query, key, and value matrices are all computed from the same input sequence. To keep generation autoregressive in a decoder-only model, the code applies a causal mask: for each position, the attention scores of all later positions are set to `-inf` before the softmax, so a token can attend only to itself and to the tokens before it. This is known as causal self-attention. Note that the sparse mixture of experts approach is not restricted to decoder-only Transformer architectures. Much of the significant work in this field, such as the Switch Transformer by Fedus, Zoph, and Shazeer, builds on the T5 architecture, which has both an encoder and a decoder."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "6f82ca41-a301-4a92-aed9-ba7ac3a2bf88",
          "showTitle": false,
          "title": ""
        },
        "id": "lxMSgZWGkRJb",
        "outputId": "8d968403-a14f-4f3c-f2ce-9099a2b34f1e"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "torch.Size([4, 8, 16])"
            ]
          },
          "execution_count": 33,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "torch.manual_seed(1337)\n",
        "B,T,C = 4,8,32 # batch, time, channels\n",
        "x = torch.randn(B,T,C)\n",
        "\n",
        "# let's see a single Head perform self-attention\n",
        "head_size = 16\n",
        "key = nn.Linear(C, head_size, bias=False)\n",
        "query = nn.Linear(C, head_size, bias=False)\n",
        "value = nn.Linear(C, head_size, bias=False)\n",
        "k = key(x)   # (B, T, 16)\n",
        "q = query(x) # (B, T, 16)\n",
        "# scale the scores by 1/sqrt(head_size): the \"scaled\" part of scaled dot-product attention\n",
        "wei = q @ k.transpose(-2, -1) * head_size**-0.5 # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
        "\n",
        "tril = torch.tril(torch.ones(T, T))\n",
        "#wei = torch.zeros((T,T))\n",
        "wei = wei.masked_fill(tril == 0, float('-inf')) # hide future positions\n",
        "wei = F.softmax(wei, dim=-1) #B,T,T\n",
        "\n",
        "v = value(x) #B,T,H\n",
        "out = wei @ v # (B,T,T) @ (B,T,H) -> (B,T,H)\n",
        "# The output of this final matrix product is subsequently passed through a linear layer, as shown in the diagram above\n",
        "\n",
        "out.shape"
      ]
    },
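    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for the causal mask, it helps to see what happens when every attention score is zero (the case hinted at by the commented-out `wei = torch.zeros((T,T))` line above). The following is a minimal sketch in plain Python rather than PyTorch; `n_tokens`, `softmax_row`, and `causal_weights` are illustrative names that do not appear elsewhere in this notebook. With all scores equal, each token simply averages over itself and the tokens before it, so row `t` puts weight `1/(t+1)` on each visible position and zero on every future position."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "n_tokens = 3  # illustrative sequence length, kept tiny so the rows are easy to read\n",
        "\n",
        "# all-zero scores: before masking, every position looks equally relevant\n",
        "scores = [[0.0] * n_tokens for _ in range(n_tokens)]\n",
        "\n",
        "# causal mask: position i may only see positions j <= i\n",
        "masked = [[s if j <= i else float('-inf') for j, s in enumerate(row)]\n",
        "          for i, row in enumerate(scores)]\n",
        "\n",
        "def softmax_row(row):\n",
        "    # subtract the max for numerical stability; exp(-inf) evaluates to exactly 0.0\n",
        "    m = max(row)\n",
        "    exps = [math.exp(v - m) for v in row]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "causal_weights = [softmax_row(row) for row in masked]\n",
        "for row in causal_weights:\n",
        "    print([round(w, 3) for w in row])"
      ]
    },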
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "49c278ec-19db-4c5d-b4a3-3bdc45c5a443",
          "showTitle": false,
          "title": ""
        },
        "id": "cA3iXggEkRJb"
      },
      "source": [
        "Generalizing and modularizing the code for causal self-attention and multi-head causal self-attention. Multi-head self-attention applies several attention heads in parallel. Each head projects the input into its own lower-dimensional subspace of size head_size, and the head outputs are concatenated back along the channel (embedding) dimension."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "608c6c9f-fb93-43ed-9580-5e782fd90d61",
          "showTitle": false,
          "title": ""
        },
        "id": "909nX3PHkRJb"
      },
      "outputs": [],
      "source": [
        "#Causal scaled dot product self-Attention Head\n",
        "\n",
        "n_embd = 64\n",
        "n_head = 4\n",
        "n_layer = 4\n",
        "head_size = 16\n",
        "dropout = 0.1\n",
        "\n",
        "class Head(nn.Module):\n",
        "    \"\"\" one head of self-attention \"\"\"\n",
        "\n",
        "    def __init__(self, head_size):\n",
        "        super().__init__()\n",
        "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
        "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        B,T,C = x.shape\n",
        "        k = self.key(x)   # (B, T, head_size)\n",
        "        q = self.query(x) # (B, T, head_size)\n",
        "        # compute attention scores (\"affinities\"), scaled by 1/sqrt(head_size) rather than 1/sqrt(n_embd)\n",
        "        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
        "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
        "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
        "        wei = self.dropout(wei)\n",
        "        # perform the weighted aggregation of the values\n",
        "        v = self.value(x) # (B, T, head_size)\n",
        "        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n",
        "        return out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "6e8b31af-f45a-4066-8288-fb0d9c8e2aff",
          "showTitle": false,
          "title": ""
        },
        "id": "T3MoVK_WkRJb"
      },
      "outputs": [],
      "source": [
        "#Multi-Headed Self Attention\n",
        "class MultiHeadAttention(nn.Module):\n",
        "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
        "\n",
        "    def __init__(self, num_heads, head_size):\n",
        "        super().__init__()\n",
        "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
        "        # project the concatenated head outputs (num_heads * head_size) back to n_embd\n",
        "        self.proj = nn.Linear(num_heads * head_size, n_embd)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
        "        out = self.dropout(self.proj(out))\n",
        "        return out\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "16267e9a-008b-46e3-82ce-2ae41396a1a1",
          "showTitle": false,
          "title": ""
        },
        "id": "T-w53_mSkRJb",
        "outputId": "052e4478-b284-43c7-924c-fcf13deea320"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "torch.Size([4, 8, 64])"
            ]
          },
          "execution_count": 68,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "#Confirming that what's output from multi head attention is the original embedding size\n",
        "B,T,C = 4,8,64 # batch, time, channels\n",
        "x = torch.randn(B,T,C)\n",
        "mha = MultiHeadAttention(4,16)\n",
        "mha(x).shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "5f7ff128-7fe5-4a91-b9f2-208e2132e505",
          "showTitle": false,
          "title": ""
        },
        "id": "wNlJTtfhkRJb"
      },
      "source": [
        "### Creating an Expert module i.e. a simple Multi Layer Perceptron"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "f6e422a5-57c1-4b2f-b7b9-2757e109848a",
          "showTitle": false,
          "title": ""
        },
        "id": "zv3fGRpbkRJb"
      },
      "source": [
        "In the sparse mixture of experts (MoE) architecture, the self-attention mechanism within each transformer block remains unchanged. What changes is the rest of the block: the single feed-forward network is replaced with several sparsely activated feed-forward networks, known as experts. \"Sparse activation\" means that each token in the sequence is routed to only a small number of these experts, typically one or two, out of the total pool. Because only the selected experts run for a given token, the model can hold many more parameters than a dense model of similar per-token compute, and different experts are free to specialize on different kinds of tokens."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "efe9fdcc-82eb-4047-9233-ad3cfe8759b1",
          "showTitle": false,
          "title": ""
        },
        "id": "7Kz0Y_P0kRJc"
      },
      "source": [
        "![experts](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/experts.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "a2f0f382-ab4a-45e1-9dce-27ae0d3da641",
          "showTitle": false,
          "title": ""
        },
        "id": "a-9CYWXgkRJc"
      },
      "outputs": [],
      "source": [
        "#Expert module\n",
        "class Expert(nn.Module):\n",
        "    \"\"\" Each expert is a small MLP: linear layer, ReLU non-linearity, linear projection back to n_embd, then dropout \"\"\"\n",
        "\n",
        "    def __init__(self, n_embd):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(n_embd, 4 * n_embd),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(4 * n_embd, n_embd),\n",
        "            nn.Dropout(dropout),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "a7764385-26e9-4d75-9aa7-ce011023e24e",
          "showTitle": false,
          "title": ""
        },
        "id": "qderdEuykRJc"
      },
      "source": [
        "### Top-k Gating Intuition through an Example"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "d3fca4df-4c47-4e9a-98cd-08cf8ccf7726",
          "showTitle": false,
          "title": ""
        },
        "id": "VxJv5y44kRJc"
      },
      "source": [
        "![top k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/topk.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "8e494b86-cdb2-4f2a-8824-5fa2ef4b2606",
          "showTitle": false,
          "title": ""
        },
        "id": "n6DuhY0DkRJc"
      },
      "source": [
        "The gating network, also known as the router, determines which experts receive each token's output from the multi-head attention block. Consider a simple example: there are 4 experts, and each token is routed to its top 2 experts. First, the token representations pass through the gating network's linear layer. This layer projects the input tensor from shape (2, 4, 32), i.e. (batch size, tokens, n_embed), where n_embed is the channel dimension of the input, to shape (2, 4, 4), i.e. (batch size, tokens, num_experts), where num_experts is the number of expert networks. We then take the top k=2 values and their indices along the last dimension."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "621916ff-2290-4e2f-9fd7-5181ed98d540",
          "showTitle": false,
          "title": ""
        },
        "id": "pNAuFDDvkRJc",
        "outputId": "a9d313d5-1dd1-4df8-9c9a-dc6d0b90d4ed"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(tensor([[[ 0.0246, -0.0190],\n",
              "          [ 0.1991,  0.1513],\n",
              "          [ 0.9749,  0.7185],\n",
              "          [ 0.4406, -0.8357]],\n",
              " \n",
              "         [[ 0.6206, -0.0503],\n",
              "          [ 0.8635,  0.3784],\n",
              "          [ 0.6828,  0.5972],\n",
              "          [ 0.4743,  0.3420]]], grad_fn=<TopkBackward0>),\n",
              " tensor([[[2, 3],\n",
              "          [2, 1],\n",
              "          [3, 1],\n",
              "          [2, 1]],\n",
              " \n",
              "         [[0, 2],\n",
              "          [0, 3],\n",
              "          [3, 2],\n",
              "          [3, 0]]]))"
            ]
          },
          "execution_count": 69,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "#Understanding how gating works\n",
        "num_experts = 4\n",
        "top_k=2\n",
        "n_embed=32\n",
        "\n",
        "\n",
        "#Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length=4 and batch_size=2\n",
        "mh_output = torch.randn(2, 4, n_embed)\n",
        "\n",
        "topkgate_linear = nn.Linear(n_embed, num_experts) # nn.Linear(32, 4)\n",
        "\n",
        "logits = topkgate_linear(mh_output)\n",
        "top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # Get top-k experts\n",
        "top_k_logits, top_k_indices"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "0f135ff7-0aa3-4b6d-ab5e-42399c48427b",
          "showTitle": false,
          "title": ""
        },
        "id": "EKwAyJxrkRJd"
      },
      "source": [
        "Obtain the sparse gating output by keeping only the top-k values at their original indices along the last dimension and filling every other position with '-inf'. Passing the result through a softmax turns each '-inf' into exactly zero and normalizes the remaining top-k values so that they sum to 1. These normalized values are the weights used to combine the outputs of the selected experts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "735e160a-ef1e-424d-b6d9-09f63ea99ec1",
          "showTitle": false,
          "title": ""
        },
        "id": "IiVejzOpkRJd",
        "outputId": "2ff61217-56f8-4835-d310-2169240fb4a1"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[[   -inf,    -inf,  0.0246, -0.0190],\n",
              "         [   -inf,  0.1513,  0.1991,    -inf],\n",
              "         [   -inf,  0.7185,    -inf,  0.9749],\n",
              "         [   -inf, -0.8357,  0.4406,    -inf]],\n",
              "\n",
              "        [[ 0.6206,    -inf, -0.0503,    -inf],\n",
              "         [ 0.8635,    -inf,    -inf,  0.3784],\n",
              "         [   -inf,    -inf,  0.5972,  0.6828],\n",
              "         [ 0.3420,    -inf,    -inf,  0.4743]]], grad_fn=<ScatterBackward0>)"
            ]
          },
          "execution_count": 70,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# full_like creates a new tensor with the same shape, dtype, and device as logits, filled with the given value\n",
        "zeros = torch.full_like(logits, float('-inf')) # despite the name, this tensor holds -inf, which acts as the mask value\n",
        "sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)\n",
        "sparse_logits"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "9146e6f9-4eee-4a8b-8338-55072719ed59",
          "showTitle": false,
          "title": ""
        },
        "id": "HFgRxDF4kRJh",
        "outputId": "cde71bbf-ba67-4972-eef1-2fba48a106d5"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[[0.0000, 0.0000, 0.5109, 0.4891],\n",
              "         [0.0000, 0.4881, 0.5119, 0.0000],\n",
              "         [0.0000, 0.4362, 0.0000, 0.5638],\n",
              "         [0.0000, 0.2182, 0.7818, 0.0000]],\n",
              "\n",
              "        [[0.6617, 0.0000, 0.3383, 0.0000],\n",
              "         [0.6190, 0.0000, 0.0000, 0.3810],\n",
              "         [0.0000, 0.0000, 0.4786, 0.5214],\n",
              "         [0.4670, 0.0000, 0.0000, 0.5330]]], grad_fn=<SoftmaxBackward0>)"
            ]
          },
          "execution_count": 71,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "gating_output = F.softmax(sparse_logits, dim=-1)\n",
        "gating_output"
      ]
    },
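    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check on the numbers above, the first row of `sparse_logits` can be pushed through a softmax by hand. This is a minimal sketch in plain Python rather than PyTorch; `softmax_by_hand`, `sparse_row`, and `hand_weights` are illustrative names that do not appear elsewhere in this notebook, and the logits are copied from the printed output above, so they are rounded to four decimals. The result should agree with the first row of `gating_output` up to that rounding: the masked experts get exactly zero weight and the two selected experts share a total weight of 1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def softmax_by_hand(row):\n",
        "    # subtract the max for numerical stability; exp(-inf) evaluates to exactly 0.0\n",
        "    m = max(row)\n",
        "    exps = [math.exp(v - m) for v in row]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "# first token of the first batch element: only experts 2 and 3 survived the top-k mask\n",
        "sparse_row = [float('-inf'), float('-inf'), 0.0246, -0.0190]\n",
        "hand_weights = softmax_by_hand(sparse_row)\n",
        "\n",
        "print([round(w, 4) for w in hand_weights])\n",
        "print('sum of weights:', sum(hand_weights))"
      ]
    },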
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "fe558b12-e443-4b62-9a85-c59120456352",
          "showTitle": false,
          "title": ""
        },
        "id": "dGJrq2uqkRJh"
      },
      "source": [
        "### Generalizing and Modularizing above code and adding noisy top-k Gating for load balancing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "45516b59-d814-4853-a34e-d36aae9f04eb",
          "showTitle": false,
          "title": ""
        },
        "id": "TKp4DqwYkRJh"
      },
      "outputs": [],
      "source": [
        "# First define the top k router module\n",
        "class TopkRouter(nn.Module):\n",
        "    def __init__(self, n_embed, num_experts, top_k):\n",
        "        super(TopkRouter, self).__init__()\n",
        "        self.top_k = top_k\n",
        "        self.linear = nn.Linear(n_embed, num_experts)\n",
        "\n",
        "    def forward(self, mh_output):\n",
        "        # mh_output is the output tensor from the multi-head self-attention block\n",
        "        logits = self.linear(mh_output)\n",
        "        top_k_logits, indices = logits.topk(self.top_k, dim=-1)\n",
        "        zeros = torch.full_like(logits, float('-inf'))\n",
        "        sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n",
        "        router_output = F.softmax(sparse_logits, dim=-1)\n",
        "        return router_output, indices\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "c500844f-0866-4bbf-acef-c0c1d4979721",
          "showTitle": false,
          "title": ""
        },
        "id": "KjkouzwkkRJh",
        "outputId": "4e5cd861-4f75-4e98-c3cf-5697fb5f7451"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(torch.Size([2, 4, 4]),\n",
              " tensor([[[0.4359, 0.0000, 0.5641, 0.0000],\n",
              "          [0.6075, 0.0000, 0.3925, 0.0000],\n",
              "          [0.6916, 0.3084, 0.0000, 0.0000],\n",
              "          [0.0000, 0.0000, 0.4342, 0.5658]],\n",
              " \n",
              "         [[0.0000, 0.4527, 0.5473, 0.0000],\n",
              "          [0.5313, 0.4687, 0.0000, 0.0000],\n",
              "          [0.0000, 0.5572, 0.4428, 0.0000],\n",
              "          [0.0000, 0.6293, 0.3707, 0.0000]]], grad_fn=<SoftmaxBackward0>),\n",
              " tensor([[[2, 0],\n",
              "          [0, 2],\n",
              "          [0, 1],\n",
              "          [3, 2]],\n",
              " \n",
              "         [[2, 1],\n",
              "          [0, 1],\n",
              "          [1, 2],\n",
              "          [1, 2]]]))"
            ]
          },
          "execution_count": 72,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "#Testing this out:\n",
        "num_experts = 4\n",
        "top_k = 2\n",
        "n_embd = 32\n",
        "\n",
        "mh_output = torch.randn(2, 4, n_embd)  # Example input\n",
        "top_k_gate = TopkRouter(n_embd, num_experts, top_k)\n",
        "gating_output, indices = top_k_gate(mh_output)\n",
        "gating_output.shape, gating_output, indices\n",
        "#And it works!!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "9fa02d0c-3688-4d01-811d-d0b2b851ab33",
          "showTitle": false,
          "title": ""
        },
        "id": "tAouN-GwkRJi"
      },
      "source": [
        "Although the recently released Mixtral paper does not mention it, I believe noisy top-k gating is an important tool in training MoE models. Essentially, you don't want all the tokens to be sent to the same set of 'favored' experts. You want a fine balance of exploitation and exploration. To help balance the load, it is useful to add Gaussian noise to the logits from the gating linear layer. In the implementation below, this is standard normal noise scaled, per token and per expert, by a learned positive factor (a softplus over a second linear layer). Spreading tokens across experts in this way tends to make training more efficient."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "e05b3306-b89f-4ebc-901b-f16398a925c2",
          "showTitle": false,
          "title": ""
        },
        "id": "aWeueE83kRJi"
      },
      "source": [
        "![noisy top-k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/noisytopkgating.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "dda4d805-373c-48f7-9037-da08fbc06e64",
          "showTitle": false,
          "title": ""
        },
        "id": "ZBrN-w3JkRJi"
      },
      "outputs": [],
      "source": [
        "# Changing the above to accommodate noisy top-k gating\n",
        "class NoisyTopkRouter(nn.Module):\n",
        "    def __init__(self, n_embed, num_experts, top_k):\n",
        "        super(NoisyTopkRouter, self).__init__()\n",
        "        self.top_k = top_k\n",
        "        # layer for router logits\n",
        "        self.topkroute_linear = nn.Linear(n_embed, num_experts)\n",
        "        # layer for the per-token, per-expert noise scale\n",
        "        self.noise_linear = nn.Linear(n_embed, num_experts)\n",
        "\n",
        "\n",
        "    def forward(self, mh_output):\n",
        "        # mh_output is the output tensor from the multi-head self-attention block\n",
        "        logits = self.topkroute_linear(mh_output)\n",
        "\n",
        "        # Noise logits\n",
        "        noise_logits = self.noise_linear(mh_output)\n",
        "\n",
        "        # Add unit Gaussian noise, scaled by softplus(noise_logits) so the scale is always positive\n",
        "        noise = torch.randn_like(logits)*F.softplus(noise_logits)\n",
        "        noisy_logits = logits + noise\n",
        "\n",
        "        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)\n",
        "        zeros = torch.full_like(noisy_logits, float('-inf'))\n",
        "        sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n",
        "        router_output = F.softmax(sparse_logits, dim=-1)\n",
        "        return router_output, indices"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "a01a9d6b-fedb-427d-b0da-c3b2a75a8643",
          "showTitle": false,
          "title": ""
        },
        "id": "7Q6KcH9AkRJi",
        "outputId": "7aea5328-ba96-4c98-b550-042ecf440700"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(torch.Size([2, 4, 8]),\n",
              " tensor([[[0.0000, 0.0000, 0.0000, 0.5903, 0.0000, 0.0000, 0.0000, 0.4097],\n",
              "          [0.6794, 0.0000, 0.0000, 0.0000, 0.0000, 0.3206, 0.0000, 0.0000],\n",
              "          [0.0000, 0.0000, 0.0000, 0.0000, 0.2743, 0.7257, 0.0000, 0.0000],\n",
              "          [0.0000, 0.0000, 0.1950, 0.0000, 0.8050, 0.0000, 0.0000, 0.0000]],\n",
              " \n",
              "         [[0.0000, 0.0000, 0.3905, 0.0000, 0.0000, 0.6095, 0.0000, 0.0000],\n",
              "          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3606, 0.6394],\n",
              "          [0.5243, 0.0000, 0.4757, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
              "          [0.0000, 0.0000, 0.0000, 0.5627, 0.0000, 0.0000, 0.4373, 0.0000]]],\n",
              "        grad_fn=<SoftmaxBackward0>),\n",
              " tensor([[[3, 7],\n",
              "          [0, 5],\n",
              "          [5, 4],\n",
              "          [4, 2]],\n",
              " \n",
              "         [[5, 2],\n",
              "          [7, 6],\n",
              "          [0, 2],\n",
              "          [3, 6]]]))"
            ]
          },
          "execution_count": 53,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "#Testing this out, again:\n",
        "num_experts = 8\n",
        "top_k = 2\n",
        "n_embd = 16\n",
        "\n",
        "mh_output = torch.randn(2, 4, n_embd)  # Example input\n",
        "noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)\n",
        "gating_output, indices = noisy_top_k_gate(mh_output)\n",
        "gating_output.shape, gating_output, indices\n",
        "#It works!!"
      ]
    },
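    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One rough way to see how evenly the router spreads tokens is to count how many assignments each expert receives. The sketch below does this in plain Python for the `indices` printed above. The values are copied by hand from that output, and `example_indices`, `n_example_experts`, and `tokens_per_expert` are illustrative names, not part of the notebook's model code. With 8 tokens and top_k = 2 there are 16 assignments in total, so a perfectly balanced router would give each of the 8 experts exactly 2 of them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from collections import Counter\n",
        "\n",
        "n_example_experts = 8\n",
        "# top-k expert indices copied from the output above: (batch=2, tokens=4, top_k=2)\n",
        "example_indices = [[[3, 7], [0, 5], [5, 4], [4, 2]],\n",
        "                   [[5, 2], [7, 6], [0, 2], [3, 6]]]\n",
        "\n",
        "# count how many (token, slot) assignments each expert received\n",
        "load = Counter(e for batch in example_indices for token in batch for e in token)\n",
        "tokens_per_expert = [load.get(i, 0) for i in range(n_example_experts)]\n",
        "\n",
        "print(tokens_per_expert)\n",
        "print('total assignments:', sum(tokens_per_expert))"
      ]
    },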
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "076fa004-a165-42a7-b729-0bca8ad39418",
          "showTitle": false,
          "title": ""
        },
        "id": "XyKjpR-dkRJi"
      },
      "source": [
        "\n",
        "### Creating a sparse Mixture of Experts module\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "6747b8de-0086-4cb0-8fbd-46ee95457eb9",
          "showTitle": false,
          "title": ""
        },
        "id": "UsRCy7i3kRJi"
      },
      "source": [
        "The gating network's output drives this module. For each token, the router returns a weight for each of its top-k experts and zero for all other experts. Each selected expert's output for that token is multiplied by its weight, and the results are summed; this weighted sum is the SparseMoE block's output. The critical and challenging part is avoiding unnecessary computation: the forward pass should run only for the top_k experts selected for each token. Running every expert on every token would defeat the purpose of a sparse MoE, as the computation would no longer be sparse."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "d6809b3f-4be9-4859-b39e-24fcdd6c8d86",
          "showTitle": false,
          "title": ""
        },
        "id": "7dDUHU_IkRJi"
      },
      "outputs": [],
      "source": [
        "class SparseMoE(nn.Module):\n",
        "    def __init__(self, n_embed, num_experts, top_k):\n",
        "        super(SparseMoE, self).__init__()\n",
        "        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)\n",
        "        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])\n",
        "        self.top_k = top_k\n",
        "\n",
        "    def forward(self, x):\n",
        "        gating_output, indices = self.router(x)\n",
        "        final_output = torch.zeros_like(x)\n",
        "\n",
        "        # Flatten the batch and time dimensions so tokens can be selected with a 1-D mask\n",
        "        flat_x = x.view(-1, x.size(-1))\n",
        "        flat_gating_output = gating_output.view(-1, gating_output.size(-1))\n",
        "\n",
        "        # Loop over the experts; each one only processes the tokens routed to it\n",
        "        for i, expert in enumerate(self.experts):\n",
        "            # Create a mask for the inputs where the current expert is in top-k\n",
        "            expert_mask = (indices == i).any(dim=-1)\n",
        "            flat_mask = expert_mask.view(-1)\n",
        "\n",
        "            if flat_mask.any():\n",
        "                expert_input = flat_x[flat_mask]\n",
        "                expert_output = expert(expert_input)\n",
        "\n",
        "                # Extract and apply gating scores\n",
        "                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)\n",
        "                weighted_output = expert_output * gating_scores\n",
        "\n",
        "                # Accumulate this expert's weighted contribution for the tokens it handled\n",
        "                final_output[expert_mask] += weighted_output\n",
        "\n",
        "        return final_output\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "06239630-0a1c-47c9-976c-7770f3d82e18",
          "showTitle": false,
          "title": ""
        },
        "id": "q8kDLI1ukRJj",
        "outputId": "69e7b330-cd47-4094-99c2-7928f8b60d0d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Shape of the final output: torch.Size([4, 8, 16])\n",
            "tensor([[[-0.0124, -0.0000, -0.0108, -0.0143, -0.0194,  0.0898, -0.0572,\n",
            "          -0.0453,  0.0351, -0.0658,  0.0853, -0.1535,  0.1538,  0.0340,\n",
            "          -0.0605, -0.0227],\n",
            "         [-0.1214,  0.0766, -0.0763, -0.0457,  0.0166, -0.0268, -0.0185,\n",
            "           0.1372, -0.1202, -0.1757,  0.1136,  0.0000, -0.2136, -0.2425,\n",
            "           0.0288,  0.0558],\n",
            "         [-0.0000,  0.2654,  0.0843,  0.0458,  0.2352,  0.2606, -0.1340,\n",
            "          -0.0359,  0.0971, -0.0165, -0.0690,  0.0838, -0.2352, -0.0203,\n",
            "           0.3269, -0.1147],\n",
            "         [-0.0240, -0.0332,  0.1032,  0.0554,  0.0276, -0.0000,  0.0204,\n",
            "          -0.0719,  0.1320, -0.1036,  0.0562,  0.0419,  0.0832, -0.2330,\n",
            "          -0.0000, -0.0260],\n",
            "         [-0.1457,  0.2298,  0.2523,  0.3260,  0.0813, -0.1724,  0.2074,\n",
            "          -0.0335,  0.1967, -0.0120, -0.0000,  0.0296, -0.0000,  0.2665,\n",
            "           0.0430,  0.1385],\n",
            "         [-0.0000, -0.0434,  0.1454, -0.0459,  0.0238, -0.0000, -0.0702,\n",
            "          -0.0082, -0.0920, -0.0000, -0.0085,  0.0000, -0.2530, -0.4184,\n",
            "          -0.1274,  0.0950],\n",
            "         [-0.0546,  0.0944,  0.2003,  0.1041,  0.0753, -0.0000,  0.0733,\n",
            "          -0.0000,  0.0678,  0.0014, -0.0723, -0.0742,  0.0435, -0.0637,\n",
            "           0.0122,  0.0000],\n",
            "         [ 0.0446,  0.0268, -0.0366,  0.1081, -0.0579, -0.1137,  0.0549,\n",
            "           0.0162,  0.0200, -0.0445, -0.0532,  0.1075, -0.0230,  0.1296,\n",
            "          -0.1638,  0.0537]],\n",
            "\n",
            "        [[-0.0390, -0.0269,  0.1131, -0.0722,  0.1846, -0.0380, -0.1055,\n",
            "          -0.0782,  0.1396,  0.0696, -0.0958,  0.0161, -0.0769, -0.1703,\n",
            "          -0.0732, -0.0180],\n",
            "         [-0.1349,  0.2449, -0.1247, -0.0149,  0.0774, -0.0732,  0.0755,\n",
            "          -0.0000,  0.0900, -0.1598, -0.1198, -0.0007, -0.0159, -0.2708,\n",
            "          -0.0636, -0.0289],\n",
            "         [ 0.0368,  0.0000, -0.1293, -0.0336,  0.0515, -0.0790,  0.0472,\n",
            "          -0.0830,  0.0000, -0.0521, -0.1743, -0.0205,  0.0320, -0.1011,\n",
            "           0.0055, -0.0228],\n",
            "         [ 0.0993, -0.0521,  0.2786, -0.0304, -0.0000,  0.1973, -0.0000,\n",
            "           0.0400,  0.0748, -0.1042, -0.3078,  0.0385, -0.2545,  0.3172,\n",
            "          -0.3621, -0.0708],\n",
            "         [ 0.0000,  0.1344,  0.0696, -0.2714, -0.1912,  0.0044, -0.1503,\n",
            "          -0.0262,  0.0000, -0.0136, -0.0329, -0.4539,  0.0990,  0.2285,\n",
            "          -0.3197,  0.0112],\n",
            "         [-0.1248,  0.1747, -0.1317, -0.1361, -0.0093, -0.0505,  0.0239,\n",
            "          -0.0009, -0.1792,  0.0079,  0.1453, -0.1140, -0.2461, -0.0000,\n",
            "           0.1578,  0.1527],\n",
            "         [-0.2012,  0.2509, -0.0933, -0.0000,  0.1976, -0.1527, -0.0379,\n",
            "           0.3109, -0.2121,  0.0012,  0.3155,  0.1832, -0.0000, -0.4221,\n",
            "           0.2019,  0.1104],\n",
            "         [-0.1177, -0.1666, -0.0471, -0.0000,  0.0864,  0.1395, -0.1160,\n",
            "           0.0994, -0.0007,  0.0159,  0.1186,  0.0068, -0.1566, -0.1419,\n",
            "          -0.0000, -0.1168]],\n",
            "\n",
            "        [[-0.1889, -0.1712, -0.0252,  0.0000,  0.1354,  0.0078, -0.3176,\n",
            "          -0.0588,  0.0764,  0.1522, -0.1602,  0.2140, -0.1569, -0.0161,\n",
            "           0.0060, -0.2089],\n",
            "         [-0.0592, -0.1243, -0.0000, -0.1623,  0.0841, -0.0263, -0.0239,\n",
            "           0.2556,  0.2700,  0.0405,  0.0842, -0.0692, -0.0505,  0.0000,\n",
            "          -0.4926,  0.1858],\n",
            "         [ 0.0551, -0.0318,  0.0195, -0.0098, -0.0351, -0.0271,  0.0646,\n",
            "           0.0014,  0.0169, -0.0162, -0.0135,  0.0070,  0.0116, -0.0627,\n",
            "          -0.0154,  0.0733],\n",
            "         [-0.1270,  0.3063, -0.1849,  0.0446, -0.0512,  0.0783, -0.0440,\n",
            "           0.1640,  0.1924, -0.1780,  0.0000, -0.0520, -0.0187, -0.0196,\n",
            "           0.3640, -0.0000],\n",
            "         [ 0.0000,  0.0442,  0.1172,  0.0441, -0.0102,  0.0660,  0.0529,\n",
            "          -0.1618,  0.1265,  0.0812, -0.1357, -0.0991, -0.0341,  0.1885,\n",
            "          -0.0000, -0.0753],\n",
            "         [ 0.0524,  0.2439, -0.1062, -0.1822,  0.1625, -0.0000,  0.2078,\n",
            "           0.1900,  0.1585, -0.1435,  0.1644,  0.0843,  0.0162,  0.0000,\n",
            "           0.0924,  0.1394],\n",
            "         [-0.0000, -0.0693,  0.0000, -0.1409, -0.0603,  0.0282,  0.1201,\n",
            "          -0.2591, -0.1273, -0.1428,  0.0000, -0.0790,  0.0134, -0.1038,\n",
            "           0.0461,  0.1344],\n",
            "         [-0.1980, -0.0087,  0.0991, -0.0000, -0.2497, -0.1385, -0.0985,\n",
            "           0.0974, -0.4163, -0.1940, -0.1078,  0.0458, -0.2890, -0.2785,\n",
            "          -0.1109,  0.1543]],\n",
            "\n",
            "        [[-0.0324,  0.0557, -0.0000, -0.1034,  0.1203,  0.0427, -0.0242,\n",
            "           0.0000, -0.1433, -0.0120,  0.1024, -0.0855, -0.0474, -0.0702,\n",
            "           0.0556,  0.0727],\n",
            "         [ 0.1194, -0.0000,  0.0720, -0.0555, -0.1901, -0.4236,  0.0438,\n",
            "           0.1632,  0.1087,  0.0192, -0.0385,  0.0332, -0.0000,  0.0847,\n",
            "           0.1672,  0.1100],\n",
            "         [-0.0620, -0.0724,  0.0220, -0.0000, -0.0757, -0.0936,  0.0000,\n",
            "           0.0757, -0.0122, -0.1435, -0.0289,  0.0000, -0.0816, -0.0648,\n",
            "          -0.0638,  0.0208],\n",
            "         [-0.1030,  0.1554,  0.1701,  0.2989,  0.0000, -0.0722,  0.1687,\n",
            "          -0.1795, -0.0000, -0.0000, -0.0096, -0.2642, -0.2379,  0.1062,\n",
            "          -0.1318,  0.3303],\n",
            "         [-0.0768,  0.1000,  0.0917, -0.0479, -0.0475,  0.0435,  0.1054,\n",
            "           0.0197,  0.1609, -0.0499, -0.2033,  0.0000,  0.0201,  0.1276,\n",
            "          -0.0000, -0.0724],\n",
            "         [ 0.0046,  0.0318,  0.0000, -0.2774, -0.0000, -0.1755, -0.0641,\n",
            "          -0.0596,  0.0422,  0.0099,  0.0066, -0.1313, -0.0019,  0.1601,\n",
            "          -0.1242,  0.0000],\n",
            "         [-0.0234,  0.3237,  0.0500,  0.0302, -0.1091, -0.0000,  0.0154,\n",
            "          -0.0177, -0.0602, -0.4930, -0.1593, -0.1109, -0.0000,  0.1590,\n",
            "           0.0000, -0.1682],\n",
            "         [-0.1681, -0.1363,  0.0161, -0.0556,  0.0000, -0.1762, -0.0621,\n",
            "           0.0830, -0.0099, -0.0000,  0.0494, -0.0318, -0.0597,  0.0600,\n",
            "           0.0852,  0.0050]]], grad_fn=<ViewBackward0>)\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "#Let's test this out\n",
        "num_experts = 8\n",
        "top_k = 2\n",
        "n_embd = 16\n",
        "dropout=0.1\n",
        "\n",
        "mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output\n",
        "sparse_moe = SparseMoE(n_embd, num_experts, top_k)\n",
        "final_output = sparse_moe(mh_output)\n",
        "print(\"Shape of the final output:\", final_output.shape)\n",
        "print(final_output)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "7476ca07-a315-4108-aa77-46173a703ca2",
          "showTitle": false,
          "title": ""
        },
        "id": "l7GoxCi2kRJj"
      },
      "source": [
        "To emphasize, it's important to recognize that the magnitudes of the top_k experts output from the Router/ gating network, as illustrated in the code above, are also significant. These top_k indices identify the experts that are activated, and the magnitude of the values in those top_k dimensions determines their respective weighting. This concept of weighted summation is further highlighted in the diagram below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "d99b5dce-301e-4380-8263-b5cfb4136ab2",
          "showTitle": false,
          "title": ""
        },
        "id": "e9a_oQ2akRJj"
      },
      "source": [
        "![sparse MoE](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/sparseMoEfinal.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "da5f3be4-f155-4d6c-bcbd-2f88a088261f",
          "showTitle": false,
          "title": ""
        },
        "id": "vMZCda7dkRJj"
      },
      "source": [
        "### Putting it all together"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "0eaf71cd-c77e-40c7-b5be-e364e91685cf",
          "showTitle": false,
          "title": ""
        },
        "id": "f8yczkFHkRJj"
      },
      "outputs": [],
      "source": [
        "#First defining hyperparameters and boiler plate code. Imports and data preparation code is repeated for convenience\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "from torch.nn import init\n",
        "\n",
        "# hyperparameters\n",
        "batch_size = 16 # how many independent sequences will we process in parallel?\n",
        "block_size = 32 # what is the maximum context length for predictions?\n",
        "max_iters = 5000\n",
        "eval_interval = 100\n",
        "learning_rate = 1e-3\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "eval_iters = 400\n",
        "head_size = 16\n",
        "n_embed = 128\n",
        "n_head = 8\n",
        "n_layer = 8\n",
        "dropout = 0.1\n",
        "num_experts = 8\n",
        "top_k = 2\n",
        "# ------------\n",
        "\n",
        "torch.manual_seed(1337)\n",
        "\n",
        "with open('input.txt', 'r', encoding='utf-8') as f:\n",
        "    text = f.read()\n",
        "\n",
        "# here are all the unique characters that occur in this text\n",
        "chars = sorted(list(set(text)))\n",
        "vocab_size = len(chars)\n",
        "# create a mapping from characters to integers\n",
        "stoi = { ch:i for i,ch in enumerate(chars) }\n",
        "itos = { i:ch for i,ch in enumerate(chars) }\n",
        "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
        "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
        "\n",
        "# Train and test splits\n",
        "data = torch.tensor(encode(text), dtype=torch.long)\n",
        "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
        "train_data = data[:n]\n",
        "val_data = data[n:]\n",
        "\n",
        "# data loading\n",
        "def get_batch(split):\n",
        "    # generate a small batch of data of inputs x and targets y\n",
        "    data = train_data if split == 'train' else val_data\n",
        "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
        "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
        "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
        "    x, y = x.to(device), y.to(device)\n",
        "    return x, y\n",
        "\n",
        "@torch.no_grad()\n",
        "def estimate_loss():\n",
        "    out = {}\n",
        "    model.eval()\n",
        "    for split in ['train', 'val']:\n",
        "        losses = torch.zeros(eval_iters)\n",
        "        for k in range(eval_iters):\n",
        "            X, Y = get_batch(split)\n",
        "            logits, loss = model(X, Y)\n",
        "            losses[k] = loss.item()\n",
        "        out[split] = losses.mean()\n",
        "    model.train()\n",
        "    return out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "ee1180f7-5004-4425-87fe-9a81a17b9024",
          "showTitle": false,
          "title": ""
        },
        "id": "QfxJ6B2fkRJj"
      },
      "outputs": [],
      "source": [
        "class Head(nn.Module):\n",
        "    \"\"\" one head of self-attention \"\"\"\n",
        "\n",
        "    def __init__(self, head_size):\n",
        "        super().__init__()\n",
        "        self.key = nn.Linear(n_embed, head_size, bias=False)\n",
        "        self.query = nn.Linear(n_embed, head_size, bias=False)\n",
        "        self.value = nn.Linear(n_embed, head_size, bias=False)\n",
        "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
        "\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        B,T,C = x.shape\n",
        "        k = self.key(x)   # (B,T,C)\n",
        "        q = self.query(x) # (B,T,C)\n",
        "        # compute attention scores (\"affinities\")\n",
        "        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
        "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
        "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
        "        wei = self.dropout(wei)\n",
        "        # perform the weighted aggregation of the values\n",
        "        v = self.value(x) # (B,T,C)\n",
        "        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
        "        return out\n",
        "\n",
        "#Multi-Headed Self Attention\n",
        "class MultiHeadAttention(nn.Module):\n",
        "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
        "\n",
        "    def __init__(self, num_heads, head_size):\n",
        "        super().__init__()\n",
        "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
        "        self.proj = nn.Linear(n_embed, n_embed)\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "\n",
        "    def forward(self, x):\n",
        "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
        "        out = self.dropout(self.proj(out))\n",
        "        return out\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "03611a92-aaa2-4e0d-9755-cba56f96c794",
          "showTitle": false,
          "title": ""
        },
        "id": "y35jVCZYkRJk"
      },
      "outputs": [],
      "source": [
        "#Expert module\n",
        "class Expert(nn.Module):\n",
        "    \"\"\" An MLP is a simple linear layer followed by a non-linearity i.e. each Expert \"\"\"\n",
        "\n",
        "    def __init__(self, n_embed):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(n_embed, 4 * n_embed),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(4 * n_embed, n_embed),\n",
        "            nn.Dropout(dropout),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.net(x)\n",
        "\n",
        "#noisy top-k gating\n",
        "class NoisyTopkRouter(nn.Module):\n",
        "    def __init__(self, n_embed, num_experts, top_k):\n",
        "        super(NoisyTopkRouter, self).__init__()\n",
        "        self.top_k = top_k\n",
        "        #layer for router logits\n",
        "        self.topkroute_linear = nn.Linear(n_embed, num_experts)\n",
        "        self.noise_linear =nn.Linear(n_embed, num_experts)\n",
        "\n",
        "\n",
        "    def forward(self, mh_output):\n",
        "        # mh_ouput is the output tensor from multihead self attention block\n",
        "        logits = self.topkroute_linear(mh_output)\n",
        "\n",
        "        #Noise logits\n",
        "        noise_logits = self.noise_linear(mh_output)\n",
        "\n",
        "        #Adding scaled unit gaussian noise to the logits\n",
        "        noise = torch.randn_like(logits)*F.softplus(noise_logits)\n",
        "        noisy_logits = logits + noise\n",
        "\n",
        "        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)\n",
        "        zeros = torch.full_like(noisy_logits, float('-inf'))\n",
        "        sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n",
        "        router_output = F.softmax(sparse_logits, dim=-1)\n",
        "        return router_output, indices\n",
        "\n",
        "#Now create the sparse mixture of experts module\n",
        "class SparseMoE(nn.Module):\n",
        "    def __init__(self, n_embed, num_experts, top_k):\n",
        "        super(SparseMoE, self).__init__()\n",
        "        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)\n",
        "        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])\n",
        "        self.top_k = top_k\n",
        "\n",
        "    def forward(self, x):\n",
        "        gating_output, indices = self.router(x)\n",
        "        final_output = torch.zeros_like(x)\n",
        "\n",
        "        # Reshape inputs for batch processing\n",
        "        flat_x = x.view(-1, x.size(-1))\n",
        "        flat_gating_output = gating_output.view(-1, gating_output.size(-1))\n",
        "\n",
        "        # Process each expert in parallel\n",
        "        for i, expert in enumerate(self.experts):\n",
        "            # Create a mask for the inputs where the current expert is in top-k\n",
        "            expert_mask = (indices == i).any(dim=-1)\n",
        "            flat_mask = expert_mask.view(-1)\n",
        "\n",
        "            if flat_mask.any():\n",
        "                expert_input = flat_x[flat_mask]\n",
        "                expert_output = expert(expert_input)\n",
        "\n",
        "                # Extract and apply gating scores\n",
        "                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)\n",
        "                weighted_output = expert_output * gating_scores\n",
        "\n",
        "                # Update final output additively by indexing and adding\n",
        "                final_output[expert_mask] += weighted_output.squeeze(1)\n",
        "\n",
        "        return final_output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "bfdff2bb-092f-41c8-9a33-c84e6f8d6633",
          "showTitle": false,
          "title": ""
        },
        "id": "jGiuRsSgkRJk"
      },
      "outputs": [],
      "source": [
        "#First create a self attention + mixture of experts block, that may be repeated several number of times\n",
        "#Copy pasting key architecture variables for clarity\n",
        "\n",
        "class Block(nn.Module):\n",
        "    \"\"\" Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) \"\"\"\n",
        "\n",
        "    def __init__(self, n_embed, n_head, num_experts, top_k):\n",
        "        # n_embed: embedding dimension, n_head: the number of heads we'd like\n",
        "        super().__init__()\n",
        "        head_size = n_embed // n_head\n",
        "        self.sa = MultiHeadAttention(n_head, head_size)\n",
        "        self.smoe = SparseMoE(n_embed, num_experts, top_k)\n",
        "        self.ln1 = nn.LayerNorm(n_embed)\n",
        "        self.ln2 = nn.LayerNorm(n_embed)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x + self.sa(self.ln1(x))\n",
        "        x = x + self.smoe(self.ln2(x))\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "2d32a276-d0cc-4808-90d7-62441771af44",
          "showTitle": false,
          "title": ""
        },
        "id": "RpyZBA71kRJk"
      },
      "outputs": [],
      "source": [
        "#Finally putting it all together to crease a sparse mixture of experts language model\n",
        "class SparseMoELanguageModel(nn.Module):\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # each token directly reads off the logits for the next token from a lookup table\n",
        "        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)\n",
        "        self.position_embedding_table = nn.Embedding(block_size, n_embed)\n",
        "        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])\n",
        "        self.ln_f = nn.LayerNorm(n_embed) # final layer norm\n",
        "        self.lm_head = nn.Linear(n_embed, vocab_size)\n",
        "\n",
        "    def forward(self, idx, targets=None):\n",
        "        B, T = idx.shape\n",
        "\n",
        "        # idx and targets are both (B,T) tensor of integers\n",
        "        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
        "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
        "        x = tok_emb + pos_emb # (B,T,C)\n",
        "        x = self.blocks(x) # (B,T,C)\n",
        "        x = self.ln_f(x) # (B,T,C)\n",
        "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
        "\n",
        "        if targets is None:\n",
        "            loss = None\n",
        "        else:\n",
        "            B, T, C = logits.shape\n",
        "            logits = logits.view(B*T, C)\n",
        "            targets = targets.view(B*T)\n",
        "            loss = F.cross_entropy(logits, targets)\n",
        "\n",
        "        return logits, loss\n",
        "\n",
        "    def generate(self, idx, max_new_tokens):\n",
        "        # idx is (B, T) array of indices in the current context\n",
        "        for _ in range(max_new_tokens):\n",
        "            # crop idx to the last block_size tokens\n",
        "            idx_cond = idx[:, -block_size:]\n",
        "            # get the predictions\n",
        "            logits, loss = self(idx_cond)\n",
        "            # focus only on the last time step\n",
        "            logits = logits[:, -1, :] # becomes (B, C)\n",
        "            # apply softmax to get probabilities\n",
        "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
        "            # sample from the distribution\n",
        "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
        "            # append sampled index to the running sequence\n",
        "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
        "        return idx"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "622ba7ce-3f20-4820-8982-93f3d3b7be09",
          "showTitle": false,
          "title": ""
        },
        "id": "bBzQXmfykRJk"
      },
      "source": [
        "Kaiming He initialization is used here because of presence of ReLU activations in the experts. Feel free to experiment with Glorot initialization which is more commonly used in transformers. Jeremy Howard's Fastai Part 2 has an excellent lecture that implements these from scratch: https://course.fast.ai/Lessons/lesson17.html"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "a6d3c057-08ee-4c1b-8013-6a88b2eadac5",
          "showTitle": false,
          "title": ""
        },
        "id": "guGaJqHbkRJk"
      },
      "outputs": [],
      "source": [
        "\n",
        "def kaiming_init_weights(m):\n",
        "    if isinstance (m, (nn.Linear)):\n",
        "        init.kaiming_normal_(m.weight)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "5b4d9525-8405-4a51-adda-661aba004e57",
          "showTitle": false,
          "title": ""
        },
        "id": "nJGGkXz4kRJl",
        "outputId": "8518ec23-caa0-4167-88c6-65a7c905743a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "SparseMoELanguageModel(\n",
              "  (token_embedding_table): Embedding(65, 128)\n",
              "  (position_embedding_table): Embedding(32, 128)\n",
              "  (blocks): Sequential(\n",
              "    (0): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (1): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (2): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (3): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (4): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (5): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (6): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "    (7): Block(\n",
              "      (sa): MultiHeadAttention(\n",
              "        (heads): ModuleList(\n",
              "          (0-7): 8 x Head(\n",
              "            (key): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (query): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (value): Linear(in_features=128, out_features=16, bias=False)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (proj): Linear(in_features=128, out_features=128, bias=True)\n",
              "        (dropout): Dropout(p=0.1, inplace=False)\n",
              "      )\n",
              "      (smoe): SparseMoE(\n",
              "        (router): NoisyTopkRouter(\n",
              "          (topkroute_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "          (noise_linear): Linear(in_features=128, out_features=8, bias=True)\n",
              "        )\n",
              "        (experts): ModuleList(\n",
              "          (0-7): 8 x Expert(\n",
              "            (net): Sequential(\n",
              "              (0): Linear(in_features=128, out_features=512, bias=True)\n",
              "              (1): ReLU()\n",
              "              (2): Linear(in_features=512, out_features=128, bias=True)\n",
              "              (3): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "    )\n",
              "  )\n",
              "  (ln_f): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
              "  (lm_head): Linear(in_features=128, out_features=65, bias=True)\n",
              ")"
            ]
          },
          "execution_count": 38,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model = SparseMoELanguageModel()\n",
        "model.apply(kaiming_init_weights)"
      ]
    },
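    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check on the module printout above, the parameter count can be worked out by hand from the layer shapes. The sketch below is only an illustration: `linear_params` is a helper made up for this cell, and the token and position embedding tables are assumed to be 65 x 128 and 32 x 128 (that is, `block_size = 32`). Under those assumptions the total matches the 8.996545 M parameters reported by the training cell, and the eight experts in each block account for most of the count."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def linear_params(n_in, n_out, bias=True):\n",
        "    # weight matrix plus optional bias vector (hypothetical helper)\n",
        "    return n_in * n_out + (n_out if bias else 0)\n",
        "\n",
        "# shapes read off the printed module tree\n",
        "n_embd, n_head, head_size = 128, 8, 16\n",
        "n_layer, n_experts, vocab = 8, 8, 65\n",
        "ctx_len = 32  # assumption: context length (block_size) used for the position table\n",
        "\n",
        "head = 3 * linear_params(n_embd, head_size, bias=False)  # key, query, value\n",
        "attention = n_head * head + linear_params(n_embd, n_embd)  # heads + output projection\n",
        "router = 2 * linear_params(n_embd, n_experts)  # routing linear + noise linear\n",
        "expert = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)\n",
        "layer_norm = 2 * n_embd  # scale and shift\n",
        "\n",
        "block = attention + router + n_experts * expert + 2 * layer_norm\n",
        "embeddings = vocab * n_embd + ctx_len * n_embd  # assumed token + position tables\n",
        "total = embeddings + n_layer * block + layer_norm + linear_params(n_embd, vocab)\n",
        "\n",
        "expert_share = n_layer * n_experts * expert / total\n",
        "print(f'total: {total:,} parameters, share held by experts: {expert_share:.1%}')"
      ]
    },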
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "6adf1d04-e668-4d14-b691-161ea4e4dccf",
          "showTitle": false,
          "title": ""
        },
        "id": "_cb1-L8ckRJl"
      },
      "source": [
        "I use MLflow to track the training hyperparameters and the metrics I care about. The training loop in the next cell includes the MLflow logging calls. If you would rather train without MLflow, a later cell has the same loop with the logging removed. That said, I find it very convenient to track parameters and metrics, particularly when experimenting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "b8968247-0d7b-4460-b96b-06743b31c55d",
          "showTitle": false,
          "title": ""
        },
        "id": "WTG1Fv4SkRJl",
        "outputId": "38318015-63a9-4959-cfe0-cab305625f49"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "8.996545 M parameters\n",
            "step 0: train loss 5.3223, val loss 5.3166\n",
            "step 100: train loss 2.7351, val loss 2.7429\n",
            "step 200: train loss 2.5125, val loss 2.5233\n",
            "step 300: train loss 2.4239, val loss 2.4384\n",
            "step 400: train loss 2.3477, val loss 2.3656\n",
            "step 500: train loss 2.2743, val loss 2.3040\n",
            "step 600: train loss 2.2087, val loss 2.2199\n",
            "step 700: train loss 2.1461, val loss 2.1853\n",
            "step 800: train loss 2.1018, val loss 2.1418\n",
            "step 900: train loss 2.0592, val loss 2.1055\n",
            "step 1000: train loss 2.0138, val loss 2.0822\n",
            "step 1100: train loss 1.9820, val loss 2.0486\n",
            "step 1200: train loss 1.9536, val loss 2.0383\n",
            "step 1300: train loss 1.9218, val loss 2.0070\n",
            "step 1400: train loss 1.9077, val loss 2.0059\n",
            "step 1500: train loss 1.8831, val loss 1.9784\n",
            "step 1600: train loss 1.8527, val loss 1.9740\n",
            "step 1700: train loss 1.8309, val loss 1.9468\n",
            "step 1800: train loss 1.8157, val loss 1.9401\n",
            "step 1900: train loss 1.7953, val loss 1.9243\n",
            "step 2000: train loss 1.7853, val loss 1.9158\n",
            "step 2100: train loss 1.7746, val loss 1.8969\n",
            "step 2200: train loss 1.7527, val loss 1.8901\n",
            "step 2300: train loss 1.7433, val loss 1.8770\n",
            "step 2400: train loss 1.7327, val loss 1.8827\n",
            "step 2500: train loss 1.7183, val loss 1.8713\n",
            "step 2600: train loss 1.7155, val loss 1.8592\n",
            "step 2700: train loss 1.7089, val loss 1.8564\n",
            "step 2800: train loss 1.6978, val loss 1.8500\n",
            "step 2900: train loss 1.6902, val loss 1.8340\n",
            "step 3000: train loss 1.6854, val loss 1.8398\n",
            "step 3100: train loss 1.6680, val loss 1.8313\n",
            "step 3200: train loss 1.6666, val loss 1.8225\n",
            "step 3300: train loss 1.6508, val loss 1.8268\n",
            "step 3400: train loss 1.6499, val loss 1.8148\n",
            "step 3500: train loss 1.6438, val loss 1.8068\n",
            "step 3600: train loss 1.6317, val loss 1.7923\n",
            "step 3700: train loss 1.6314, val loss 1.7856\n",
            "step 3800: train loss 1.6260, val loss 1.7862\n",
            "step 3900: train loss 1.6092, val loss 1.7757\n",
            "step 4000: train loss 1.6116, val loss 1.7753\n",
            "step 4100: train loss 1.6024, val loss 1.7823\n",
            "step 4200: train loss 1.6067, val loss 1.7650\n",
            "step 4300: train loss 1.5895, val loss 1.7629\n",
            "step 4400: train loss 1.5940, val loss 1.7623\n",
            "step 4500: train loss 1.5937, val loss 1.7506\n",
            "step 4600: train loss 1.5941, val loss 1.7743\n",
            "step 4700: train loss 1.5787, val loss 1.7646\n",
            "step 4800: train loss 1.5786, val loss 1.7585\n",
            "step 4900: train loss 1.5719, val loss 1.7439\n",
            "step 4999: train loss 1.5712, val loss 1.7508\n"
          ]
        }
      ],
      "source": [
        "# Using MLflow\n",
        "m = model.to(device)\n",
        "# print the number of parameters in the model\n",
        "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
        "\n",
        "# create a PyTorch optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
        "#mlflow.set_experiment(\"makeMoE\")\n",
        "with mlflow.start_run():\n",
        "    # Log the hyperparameters explicitly so every run records them; autologging may not capture values defined in a hand-written training loop\n",
        "    params = {\"batch_size\": batch_size , \"block_size\" : block_size, \"max_iters\": max_iters, \"eval_interval\": eval_interval,\n",
        "              \"learning_rate\": learning_rate, \"device\": device, \"eval_iters\": eval_iters, \"dropout\" : dropout, \"num_experts\": num_experts, \"top_k\": top_k }\n",
        "    mlflow.log_params(params)\n",
        "    for iter in range(max_iters):\n",
        "\n",
        "        # every once in a while evaluate the loss on train and val sets\n",
        "        if iter % eval_interval == 0 or iter == max_iters - 1:\n",
        "            losses = estimate_loss()\n",
        "            print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
        "            metrics = {\"train_loss\": float(losses['train']), \"val_loss\": float(losses['val'])}\n",
        "            mlflow.log_metrics(metrics, step=iter)\n",
        "\n",
        "\n",
        "        # sample a batch of data\n",
        "        xb, yb = get_batch('train')\n",
        "\n",
        "        # evaluate the loss\n",
        "        logits, loss = model(xb, yb)\n",
        "        optimizer.zero_grad(set_to_none=True)\n",
        "        loss.backward()\n",
        "        optimizer.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "1ed96085-c292-4624-a2cc-be8aad38df79",
          "showTitle": false,
          "title": ""
        },
        "id": "jFBVfgfekRJl"
      },
      "source": [
        "Logging the train and validation losses gives a good indication of how training is going. The plot suggests I probably should have stopped at around 4500 steps, where the validation loss jumps up a bit.\n",
        "\n",
        "![mlflow_dash](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/mlflow_dash.png)"
      ]
    },
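    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to act on a plot like this is patience-based early stopping: stop once the validation loss has failed to improve for a fixed number of evaluations in a row. The cell below is a minimal sketch rather than part of the training loop above; `EarlyStopper`, `patience`, and `min_delta` are hypothetical names introduced for illustration, and the replayed losses are copied from the tail of the training log printed earlier. With `patience=2` the run would stop at step 4700 and keep step 4500 as its best checkpoint. A more patient setting would have reached the slightly lower loss logged at step 4900, so the choice of patience matters."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "\n",
        "@dataclass\n",
        "class EarlyStopper:\n",
        "    # hypothetical helper, not part of the original notebook\n",
        "    patience: int = 2          # evaluations without improvement before stopping\n",
        "    min_delta: float = 0.0     # minimum decrease that counts as an improvement\n",
        "    best_loss: float = float('inf')\n",
        "    best_step: int = -1\n",
        "    bad_evals: int = 0\n",
        "\n",
        "    def update(self, step, val_loss):\n",
        "        # record a new best, or count one more evaluation without improvement\n",
        "        if val_loss < self.best_loss - self.min_delta:\n",
        "            self.best_loss, self.best_step, self.bad_evals = val_loss, step, 0\n",
        "        else:\n",
        "            self.bad_evals += 1\n",
        "        return self.bad_evals >= self.patience  # True means stop training\n",
        "\n",
        "# (step, val loss) pairs copied from the end of the training log\n",
        "val_history = [(4000, 1.7753), (4100, 1.7823), (4200, 1.7650), (4300, 1.7629),\n",
        "               (4400, 1.7623), (4500, 1.7506), (4600, 1.7743), (4700, 1.7646),\n",
        "               (4800, 1.7585), (4900, 1.7439), (4999, 1.7508)]\n",
        "\n",
        "stopper = EarlyStopper(patience=2)\n",
        "stopped_at = None\n",
        "for step, val_loss in val_history:\n",
        "    if stopper.update(step, val_loss):\n",
        "        stopped_at = step\n",
        "        break\n",
        "\n",
        "print(f'stopped at step {stopped_at}; best val loss {stopper.best_loss:.4f} at step {stopper.best_step}')"
      ]
    },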
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {},
          "inputWidgets": {},
          "nuid": "6360e1b7-94c4-4ef1-a850-9bc93f49a083",
          "showTitle": false,
          "title": ""
        },
        "id": "WP_2lcRUkRJl"
      },
      "outputs": [],
      "source": [
        "# Not using MLflow\n",
        "m = model.to(device)\n",
        "# print the number of parameters in the model\n",
        "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
        "\n",
        "# create a PyTorch optimizer\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
        "\n",
        "for iter in range(max_iters):\n",
        "\n",
        "    # every once in a while evaluate the loss on train and val sets\n",
        "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
        "        losses = estimate_loss()\n",
        "        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
        "\n",
        "    # sample a batch of data\n",
        "    xb, yb = get_batch('train')\n",
        "\n",
        "    # evaluate the loss\n",
        "    logits, loss = model(xb, yb)\n",
        "    optimizer.zero_grad(set_to_none=True)\n",
        "    loss.backward()\n",
        "    optimizer.step()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "application/vnd.databricks.v1+cell": {
          "cellMetadata": {
            "byteLimit": 2048000,
            "rowLimit": 10000
          },
          "inputWidgets": {},
          "nuid": "8aa6e4c4-c688-4985-a3b8-e2af1f771e54",
          "showTitle": false,
          "title": ""
        },
        "id": "W4yshpXMkRJl",
        "outputId": "e69af1c7-2e2f-475c-eeba-b20dec8532c2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "And, rem?\n",
            "He say, the soul froom, from Merver:\n",
            "And cirture muck on this part son Turn the will:\n",
            "The fulist somet on glace: O they were\n",
            "Le I scide, on thouch scaribe hoptant would\n",
            "And chnot did the croblarious too contle your life,,\n",
            "Some lo you her, alas forfour-porth, it, see!\n",
            "What, when is a strong of though first.\n",
            "\n",
            "DUKE VINCENVENTIO:\n",
            "If it ever fecond he town sue kigh now,\n",
            "That thou wold'st is steen 't.\n",
            "\n",
            "SIMNA:\n",
            "Angent her; no, my a born Yorthort,\n",
            "Romeoos soun and lawf to your sawe with ch a woft ttastly defy,\n",
            "To declay the soul art; and meart smad.\n",
            "\n",
            "CORPIOLLANUS:\n",
            "Which I cannot shall do from by born und ot cold warrike,\n",
            "What king we best anone wrave's going of heard and good\n",
            "Thus playvage; you have wold the grace.\n",
            "\n",
            "KING EDWARD:\n",
            "Daughtia I they honour of your king thissand yish, there Marcius has found Romeo\n",
            "And Havaulint cound is the and such shall diake her forture\n",
            "The rights:' chy Villona, in I thenruns!\n",
            "\n",
            "What you, in posinn the dayms bore,\n",
            "As recompe sontts stand, where thronough is thy gracieful misic?\n",
            "\n",
            "DUKE VINCENCENTIO:\n",
            "When I fangly wilt dovines ear ach surch to mearts treck impost not,\n",
            "Lord, go my abeging thus, igainst onlove's incruent!\n",
            "\n",
            "Pum, God GlouNT:\n",
            "Why, to this unclomer bed: who we priol,\n",
            "Farewell eye! first me, thosto I are the affair yea, to not, when\n",
            "on you his give magetine.\n",
            "\n",
            "DUKE VINCENTIO:\n",
            "I have madast true; he hope ever thrubt a got to crave.\n",
            "\n",
            "ClONTISTINGNUS:\n",
            "TYet ago, cumbot to down withat these was the connoure.\n",
            "This, when your elses: deaths, Henry, thrink and of mennine effirm'd, 'tis littly Beardings agay\n",
            "Teld, are is steek make rough, in\n",
            "having givant and of thy that jucing an,\n",
            "But, over no some.\n",
            "\n",
            "WARWICK:\n",
            "I am last, cromend, Cablarlius;\n",
            "That everceigal whils ghaves nobled him ineace?\n",
            "\n",
            "LEONTES:\n",
            "May, It it say; was thou suit hast\n",
            "on yough. I am follow to the balk it;\n",
            "Sixcted Harce petchol by his doubard'd,\n",
            "I am chouncealos! a my life, and see me jeblearry, be with fliful so son!\n",
            "\n",
            "LEONTES:\n",
            "How wus have up frompriton!\n",
            "\n",
            "BRUTUS:\n",
            "To and than \n"
          ]
        }
      ],
      "source": [
        "# generate from the model. Not great. Not too bad either\n",
        "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
        "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"
      ]
    }
  ],
  "metadata": {
    "application/vnd.databricks.v1+notebook": {
      "dashboards": [],
      "language": "python",
      "notebookMetadata": {
        "pythonIndentUnit": 4
      },
      "notebookName": "makeMoE_from_Scratch",
      "widgets": {}
    },
    "colab": {
      "include_colab_link": true,
      "provenance": []
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
